[OMNIML-2663] Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes #852

Merged
kevalmorabia97 merged 11 commits into main from ajrasane/onnx_qdq
Mar 17, 2026

Conversation

@ajrasane
Contributor

@ajrasane ajrasane commented Feb 4, 2026

What does this PR do?

Type of change:
New feature

Overview:

  • Updated FP8 quant exporter to replace modelopt custom QDQ nodes with native ONNX QDQ nodes
  • Updated get_onnx_bytes_and_metadata to make convert_float_to_float16() the default instead of autocast
  • Created util functions to fix graph structure after conversion
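Conceptually, the replacement is a small rewrite pass over the graph. The sketch below is illustrative only (plain dicts instead of onnx-graphsurgeon nodes; replace_trt_fp8_qdq and the dict layout are invented for this example), but it mirrors the rule: rename the op, append a zero-valued FP8 zero_point when only (input, scale) are present, and set the opset-19 saturate attribute:

```python
def replace_trt_fp8_qdq(nodes):
    """Rewrite TRT_FP8(De)QuantizeLinear nodes to native ONNX ops in place."""
    for node in nodes:
        if node["op"] == "TRT_FP8QuantizeLinear":
            node["op"] = "QuantizeLinear"
            # Native QuantizeLinear takes an optional zero_point input; add a
            # zero-valued FP8 zero point when only (input, scale) are present.
            if len(node["inputs"]) == 2:
                node["inputs"].append(node["name"] + "_zero_point")
            # opset-19 QuantizeLinear defines `saturate` for FP8 targets.
            node["attrs"]["saturate"] = 1
        elif node["op"] == "TRT_FP8DequantizeLinear":
            node["op"] = "DequantizeLinear"
    return nodes

nodes = [
    {"name": "q0", "op": "TRT_FP8QuantizeLinear", "inputs": ["x", "s"], "attrs": {}},
    {"name": "dq0", "op": "TRT_FP8DequantizeLinear", "inputs": ["y", "s", "zp"], "attrs": {}},
]
replace_trt_fp8_qdq(nodes)
print([n["op"] for n in nodes])  # ['QuantizeLinear', 'DequantizeLinear']
```

The actual exporter does the same rewrite on a graphsurgeon graph and then runs cleanup and toposort.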

Testing

python torch_quant_to_onnx.py --quantize_mode=fp8 \
	--onnx_save_path=<model_path> \
	--calibration_data_size 64 \
	--batch_size 128

python evaluate.py --onnx_path=<model_path> \
	--model_name=vit_base_patch16_224 \
	--results_path=./results.txt \
	--batch_size 128

Results (model: vit_base_patch16_224):
Before replacement:

The top1 accuracy of the model is 85.06%
The top5 accuracy of the model is 97.558%
Inference latency of the model is 5.27963 ms

After replacement:

The top1 accuracy of the model is 85.054%
The top5 accuracy of the model is 97.542%
Inference latency of the model is 5.74771 ms

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: No
  • Replaced modelopt QDQ nodes with native ONNX QDQ nodes
  • Did you write any new necessary tests?: No
  • Did you add or update any necessary documentation?: No
  • Did you update Changelog?: No

Summary by CodeRabbit

  • New Features

    • ONNX utilities to remove redundant Casts, fold Constant→Cast patterns, and convert targeted Casts to FP16.
  • Improvements

    • FP8 QDQ nodes now converted to native ONNX QuantizeLinear/DequantizeLinear nodes for improved compatibility.
    • Export pipeline streamlined: consistent FP16 handling, unified weight quantization, cast cleanup ordering, and added logging for better traceability.
  • Tests

    • Unit tests updated to use the new ONNX utilities.
  • Changelog

    • Entry added noting FP8 QDQ → native ONNX QDQ conversion.

@ajrasane ajrasane requested review from a team as code owners February 4, 2026 01:08
@ajrasane ajrasane requested a review from cjluo-nv February 4, 2026 01:08
@coderabbitai
Contributor

coderabbitai bot commented Feb 4, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Adds FP8 post-processing converting TensorRT-specific FP8 QDQ nodes to native ONNX QDQ, introduces ONNX Cast utilities for redundant-cast removal and targeted FP16 conversion, and updates Torch ONNX export to apply these utilities with a reordered FP16/quantization pipeline.

Changes

Cohort / File(s) Summary
FP8 Export Post-Processing
modelopt/onnx/export/fp8_exporter.py
Replaced a no-op post_process with logic converting TRT_FP8QuantizeLinear → QuantizeLinear (adds FP8 zero_point if missing, sets saturate) and TRT_FP8DequantizeLinear → DequantizeLinear; runs graph cleanup/toposort, re-exports ONNX, and adds logging. Updated compress_weights docs/comments.
ONNX Cast Utilities
modelopt/onnx/utils.py
Added utilities to read Cast target types, get producer/consumer nodes, stash/replace tensor names, detect/fold redundant casts (same-type, sequential, Constant→Cast), bypass/rewrite connections, fold Constant→Cast, plus remove_redundant_casts(onnx.ModelProto) and change_casts_to_fp16(model, target_op_types). Minor tweak to randomize_weights_onnx_bytes metadata access.
Precision Converter (autocast)
modelopt/onnx/autocast/precisionconverter.py
Removed internal cast-management helpers and delegated those responsibilities to centralized onnx_utils functions (producer/consumer lookups, bypassing, type queries, redundant-cast removal). Updated flows to call onnx_utils.remove_redundant_casts(self.model) and related helpers.
Autocast Utilities / GraphSanitizer
modelopt/onnx/autocast/utils.py, modelopt/onnx/autocast/graphsanitizer.py
Replaced local producer/consumer/cast helpers with calls to onnx_utils (removed get_consumer_nodes, get_producer_nodes, get_cast_to_type from autocast utils). GraphSanitizer updated to use onnx_utils equivalents.
Torch ONNX Integration
modelopt/torch/_deploy/utils/torch_onnx.py
Imported and exposed change_casts_to_fp16 and remove_redundant_casts, patched onnxconverter_common.remove_unnecessary_cast_node with a suppress wrapper, and reordered FP16/quantization pipeline: always quantize weights, apply FP16 weight conversion, convert FP32 Casts feeding Concat/Add to FP16, then run redundant-cast removal prior to IR/external-data handling.
NVFP4 Exporter minor tweak
modelopt/onnx/export/nvfp4_exporter.py
Replaced dict.get(..., None) with dict.get(...) for initializer lookups in three locations (stylistic; behavior unchanged).
Tests
tests/unit/onnx/autocast/test_precisionconverter.py
Updated test assertions to use onnx_utils.get_consumer_nodes instead of removed local utils.get_consumer_nodes.
Changelog
CHANGELOG.rst
Added entry noting modelopt FP8 QDQ nodes are replaced with native ONNX QDQ nodes (0.43).
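To illustrate the cast utilities listed above, here is a hedged, stdlib-only sketch of the folding idea behind remove_redundant_casts. The real utility operates on onnx.ModelProto and applies extra safety checks (consumer counts, graph outputs); fold_redundant_casts, the tuple encoding, and the assumption of a linear cast chain are invented for this example:

```python
def fold_redundant_casts(casts):
    """Collapse a linear chain of (op, src_dtype, dst_dtype) casts.

    Same-type casts are dropped, and sequential casts A->B, B->C fold into
    A->C (dropped entirely when A == C).
    """
    folded = []
    for op, src, dst in casts:
        if src == dst:  # Cast to the tensor's own type: a no-op
            continue
        if folded and folded[-1][2] == src:
            # Fold with the previous cast in the chain
            prev_op, prev_src, _ = folded.pop()
            if prev_src != dst:
                folded.append((prev_op, prev_src, dst))
            continue
        folded.append((op, src, dst))
    return folded

chain = [("Cast", "fp32", "fp16"), ("Cast", "fp16", "fp32"), ("Cast", "fp32", "fp32")]
print(fold_redundant_casts(chain))  # []  (the whole chain is a no-op)
```

Note that folding fp32 → fp16 → fp32 is not strictly value-preserving; the production pass decides case by case whether such a roundtrip is safe to remove.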

Sequence Diagram(s)

sequenceDiagram
    participant Exporter as FP8 Exporter
    participant Graph as ONNX Graph
    participant TRT as TRT_FP8 Nodes
    participant Native as Native ONNX Ops
    participant Cleaner as Graph Cleaner
    Exporter->>Graph: scan for TRT_FP8QuantizeLinear / TRT_FP8DequantizeLinear
    Graph->>TRT: identify TRT_FP8QuantizeLinear nodes
    Exporter->>Graph: for each TRT_FP8QuantizeLinear -> create `zero_point` const if missing
    Exporter->>Graph: replace TRT_FP8QuantizeLinear with QuantizeLinear (set saturate)
    Graph->>TRT: identify TRT_FP8DequantizeLinear nodes
    Exporter->>Graph: replace TRT_FP8DequantizeLinear with DequantizeLinear
    Exporter->>Cleaner: invoke cleanup & topological sort
    Cleaner->>Graph: remove unused nodes, fix edges, toposort
    Cleaner->>Native: graph now uses native ONNX QDQ nodes
    Exporter->>Exporter: export cleaned ONNX model
    Note over Exporter,Cleaner: logger.info/debug traces conversions

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 73.68% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and specifically describes the main change: replacing Model-Optimizer's custom FP8 QDQ nodes with native ONNX QDQ nodes, which aligns perfectly with the PR's primary objective and the majority of changes across multiple files.
Security Anti-Patterns ✅ Passed No security anti-patterns found: torch.load(), numpy.load(), trust_remote_code=True, eval/exec, # nosec comments, or new non-permissive dependencies in modified files.


Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 104-140: The post_process function's docstring mentions updating
GELU nodes to tanh approximation and inserting Cast nodes after Sqrt, but the
implementation in post_process only converts
TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear to
QuantizeLinear/DequantizeLinear; either remove or revise those docstring lines
to reflect current behavior, or implement the missing steps: locate GELU nodes
in graph.nodes and replace/modify them to the tanh-approx variant, and insert
Cast nodes immediately after Sqrt nodes' outputs; reference post_process,
TRT_FP8QuantizeLinear, TRT_FP8DequantizeLinear, GELU, and Sqrt when making the
change.
- Around line 119-126: The FP8 zero-point tensor zp_tensor is missing explicit
shape metadata; update the creation of zp_tensor (used to build zero_point and
appended to node.inputs) to set its dims explicitly (e.g., call
zp_tensor.dims.extend([1]) for a 1-element tensor) so it matches other tensors
created in this module (see the FP8 weights tensor creation) and ensures ONNX
runtimes receive shape info.

In `@modelopt/onnx/utils.py`:
- Around line 1314-1349: In change_casts_to_fp16, only modify Cast nodes that
actually cast from FP32: for each Cast node (node.op_type == "Cast") look up the
source tensor name node.input[0] in graph.initializer, graph.input,
graph.value_info or graph.output to get its element_type and only change the
node.attribute "to" from onnx.TensorProto.FLOAT to onnx.TensorProto.FLOAT16 if
the source dtype is FLOAT; also avoid changing Casts that are FP16->FP32 and add
a debug log entry when you modify a Cast (include node.name or node.output[0]
and original->new dtypes) to aid debugging.
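A minimal sketch of the guarded pass that prompt describes, with tensor dtypes precomputed into a name → dtype map. All names and the dict-based node encoding here are illustrative, not the modelopt API; the real code would gather dtypes from graph.initializer, graph.input, graph.value_info, and graph.output:

```python
FLOAT, FLOAT16 = "FLOAT", "FLOAT16"

def retarget_fp32_casts(cast_nodes, tensor_dtypes):
    """Retarget Cast-to-FP32 nodes to FP16, but only when the source is FP32."""
    changed = []
    for node in cast_nodes:
        src_dtype = tensor_dtypes.get(node["input"])
        # Leave FP16->FP32 (and unknown-source) casts untouched, as the
        # review comment above recommends.
        if src_dtype == FLOAT and node["to"] == FLOAT:
            node["to"] = FLOAT16
            changed.append(node["name"])
    return changed

casts = [
    {"name": "c0", "input": "x", "to": FLOAT},  # FP32 source: retargeted
    {"name": "c1", "input": "h", "to": FLOAT},  # FP16 source: left alone
]
print(retarget_fp32_casts(casts, {"x": FLOAT, "h": FLOAT16}))  # ['c0']
```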
🧹 Nitpick comments (1)
modelopt/onnx/utils.py (1)

1218-1261: Consider edge case where first Cast has multiple consumers.

The function checks len(node.outputs[0].outputs) != 1 (line 1231) to ensure the first Cast's output goes to exactly one node. However, this may be overly restrictive. If the first Cast feeds into a duplicate second Cast AND other nodes, you could still remove the duplicate Cast while preserving the connection to other consumers. The current logic skips this optimization opportunity.

This is a minor optimization opportunity and the current implementation is safe.

Comment thread modelopt/onnx/export/fp8_exporter.py
Comment thread modelopt/onnx/export/fp8_exporter.py
Comment thread modelopt/onnx/utils.py
@codecov

codecov bot commented Feb 4, 2026

Codecov Report

❌ Patch coverage is 77.93427% with 47 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.06%. Comparing base (58417e5) to head (de464f9).
⚠️ Report is 10 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/onnx/utils.py 78.69% 36 Missing ⚠️
modelopt/torch/_deploy/utils/torch_onnx.py 63.63% 8 Missing ⚠️
modelopt/onnx/autocast/precisionconverter.py 78.57% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #852      +/-   ##
==========================================
- Coverage   70.07%   70.06%   -0.02%     
==========================================
  Files         221      221              
  Lines       25499    25603     +104     
==========================================
+ Hits        17869    17939      +70     
- Misses       7630     7664      +34     

☔ View full report in Codecov by Sentry.

@gcunhase
Contributor

gcunhase commented Feb 9, 2026

@ajrasane can you please add the before and after accuracy results in the PR description? I.e: with FP8 custom Q/DQ nodes vs with FP8 native Q/DQ nodes. Thanks!

@gcunhase
Contributor

gcunhase commented Feb 9, 2026

Let's also add this change to the Changelog file.

Comment thread modelopt/onnx/utils.py Outdated
op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
)
# Change FP32 cast nodes feeding into Concat/Add to FP16
onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"])
Contributor

Can you please elaborate the goal/need of this function? Thanks!

Contributor Author

This is because after the convert_float_to_float16() function, one of the inputs for these nodes is FP16, while the other is FP32. Hence we run into a compilation issue with TensorRT. To fix this, I manually update them here for these operators.

Contributor

Got it, thanks for the explanation. Can you please update the docstring to give a bit more details? Thanks!

Contributor

@ajrasane would you consider using autocast's convert_to_f16 and avoid this patch?

Contributor Author

I run into an error while building the engine with TensorRT:

[03/12/2026-21:06:15] [E] Error[9]: Error Code: 9: Skipping tactic 0x0000000000000000 due to exception [myelin_graph.h:attachExceptionMsgToGraph:1146] MyelinCheckException: operand.h:456: CHECK(is_tensor()) failed.  In compileGraph at optimizer/myelin/codeGenerator.cpp:1421
[03/12/2026-21:06:15] [E] Error[10]: IBuilder::buildSerializedNetworkToStream: Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[x_cast_to_fp16...(Unnamed Layer* 1752) [ElementWise]]}. In computeCosts at optimizer/common/tactic/optimizer.cpp:4115)

I also remember that you had previously mentioned that autocast is not supposed to be used after quantization, as it would need a separate design. Hence I removed it from here. Let me know if that is no longer the case.

Comment thread modelopt/onnx/export/fp8_exporter.py
@ajrasane ajrasane requested a review from a team as a code owner February 13, 2026 14:09
@ajrasane ajrasane requested a review from galagam February 13, 2026 14:09
@gcunhase
Contributor

gcunhase commented Feb 13, 2026

5.74771 ms

Accuracy looks good, any idea why perf is slower after this PR?

Also, can you please specify which model these numbers are for?

Thanks.

Contributor

@gcunhase gcunhase left a comment

LGTM, @galagam are you okay with making the redundant casts function a utils function? Thanks!

Contributor

@galagam galagam left a comment

LGTM

op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
)
# Change FP32 cast nodes feeding into Concat/Add to FP16
onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"])
Contributor

@ajrasane would you consider using autocast's convert_to_f16 and avoid this patch?

Comment thread modelopt/onnx/utils.py Outdated
logger.debug(f"Failed to fold Constant->Cast {node.name}: {e}")

if removed_count > 0:
graph.cleanup().toposort()
Contributor

I recall some issues with toposort.
If you see any failures due to it, we can probably omit it; _bypass_cast maintains node sorting.

@galagam
Contributor

galagam commented Feb 15, 2026

LGTM, @galagam are you okay with making the redundant casts function a utils function? Thanks!

AutoCast's unit testing covers this part well, and indeed, I see there are quite a few failures with this refactor.
Approved the general concept, but need to make sure we don't cause regressions/behavior changes for AutoCast.
Thanks.
@gcunhase @ajrasane

@ajrasane ajrasane force-pushed the ajrasane/onnx_qdq branch from 0186223 to 788313f Compare March 12, 2026 18:37
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)

1151-1172: ⚠️ Potential issue | 🟠 Major

This shared cast cleanup can undo output-name preservation.

_cleanup() already fixes network output names before this call, but onnx_utils.remove_redundant_casts() bypasses output casts by replacing graph.outputs with the cast input. For redundant casts on model outputs, that reverts the exported output tensor name to the pre-cast name and can trip _sanity_check() or break the public I/O contract.

🧹 Nitpick comments (1)
modelopt/torch/_deploy/utils/torch_onnx.py (1)

62-71: Scope this onnxconverter_common workaround to the conversion call.

Patching the module at import time changes behavior process-wide, and suppress(AttributeError) hides every upstream AttributeError, not just the known list/attr bug. A temporary patch around convert_float_to_float16() is much safer, and it avoids making this module import brittle if the upstream symbol changes.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/_deploy/utils/torch_onnx.py` around lines 62 - 71, The current
import-time monkey-patch of _f16_module.remove_unnecessary_cast_node (using
_original_remove_unnecessary_cast_node and
_patched_remove_unnecessary_cast_node) is global and hides all AttributeError
via suppress(AttributeError); instead, scope the workaround only around the call
to convert_float_to_float16(): before calling convert_float_to_float16() save
the original _f16_module.remove_unnecessary_cast_node, replace it with a minimal
wrapper that only catches the specific list/attribute error, call
convert_float_to_float16(), and finally restore the original function in a
try/finally so the patch is temporary and does not swallow unrelated
AttributeErrors or affect the rest of the process.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 121-147: In FP8QuantExporter.post_process(), before converting
TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear to native
QuantizeLinear/DequantizeLinear using FLOAT8E4M3FN and the saturate attribute,
validate the model opset version is >= 19; locate the method
FP8QuantExporter.post_process and check the graph/model opset (opset_import or
graph.model.opset_import) and if opset < 19 either raise a clear exception
(e.g., ValueError) telling callers to use onnx_opset >= 19 or programmatically
upgrade the model opset to 19 before performing the conversions (and then
proceed with the existing replacement logic for TRT_FP8QuantizeLinear and
TRT_FP8DequantizeLinear).

In `@modelopt/torch/_deploy/utils/torch_onnx.py`:
- Around line 576-599: The model_metadata is built too early and can become
stale after graph rewrites; move the metadata creation so it runs after all ONNX
mutations: after quantize_weights(), qdq_to_dq(), convert_float_to_float16(),
change_casts_to_fp16(), remove_redundant_casts(), and
replace_zero_scale_with_smallest_nonzero(), and ensure you rebuild it after
setting onnx_opt_graph.ir_version = 10 so the returned metadata matches the
final serialized graph bytes.
- Around line 581-588: The FP16 export path should not use torch.autocast during
tracing because you already perform explicit post-export conversion with
convert_float_to_float16; update the autocast logic so the
torch.autocast("cuda") context is only entered when weights_dtype == "bf16" (and
not when weights_dtype == "fp16"), i.e., change the condition that currently
enables autocast for weights_dtype != "fp32" to specifically check for "bf16"
and leave the FP16 path to rely solely on convert_float_to_float16; keep the
convert_float_to_float16 call for FP16 unchanged.

---

Nitpick comments:
In `@modelopt/torch/_deploy/utils/torch_onnx.py`:
- Around line 62-71: The current import-time monkey-patch of
_f16_module.remove_unnecessary_cast_node (using
_original_remove_unnecessary_cast_node and
_patched_remove_unnecessary_cast_node) is global and hides all AttributeError
via suppress(AttributeError); instead, scope the workaround only around the call
to convert_float_to_float16(): before calling convert_float_to_float16() save
the original _f16_module.remove_unnecessary_cast_node, replace it with a minimal
wrapper that only catches the specific list/attribute error, call
convert_float_to_float16(), and finally restore the original function in a
try/finally so the patch is temporary and does not swallow unrelated
AttributeErrors or affect the rest of the process.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: a86e0e09-2562-47a8-9985-98fa922fd2f0

📥 Commits

Reviewing files that changed from the base of the PR and between 2ebf0a2 and 788313fa4a998c8f18315b736310f700b29a6a21.

📒 Files selected for processing (6)
  • CHANGELOG.rst
  • modelopt/onnx/autocast/precisionconverter.py
  • modelopt/onnx/export/fp8_exporter.py
  • modelopt/onnx/export/nvfp4_exporter.py
  • modelopt/onnx/utils.py
  • modelopt/torch/_deploy/utils/torch_onnx.py
✅ Files skipped from review due to trivial changes (1)
  • CHANGELOG.rst
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/onnx/utils.py

Comment on lines +121 to +147
# Convert TRT_FP8QuantizeLinear to native QuantizeLinear
for node in graph.nodes:
    if node.op == "TRT_FP8QuantizeLinear":
        node.op = "QuantizeLinear"
        # Add FP8 zero_point if not present
        if len(node.inputs) == 2:
            # Create FP8 zero point constant
            zp_tensor = onnx.TensorProto()
            zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
            zp_tensor.dims.extend([1])  # 1-element tensor
            zp_tensor.raw_data = b"\x00"  # Zero in FP8
            zp_values = LazyValues(zp_tensor)
            zero_point = gs.Constant(node.name + "_zero_point", zp_values)
            node.inputs.append(zero_point)
        # Add saturate attribute for FP8
        node.attrs["saturate"] = 1
        logger.debug(f"Converted {node.name} from TRT_FP8QuantizeLinear to QuantizeLinear")

# Convert TRT_FP8DequantizeLinear to native DequantizeLinear
for node in graph.nodes:
    if node.op == "TRT_FP8DequantizeLinear":
        node.op = "DequantizeLinear"
        logger.debug(
            f"Converted {node.name} from TRT_FP8DequantizeLinear to DequantizeLinear"
        )

graph.cleanup().toposort()
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

Using the official ONNX operator and type documentation, what is the minimum ai.onnx opset that supports FLOAT8E4M3FN with QuantizeLinear and DequantizeLinear, including the saturate attribute on QuantizeLinear?

💡 Result:

Minimum ai.onnx opset: 19.

  • QuantizeLinear first supports tensor(float8e4m3fn) (via its T2 type constraint) in since_version = 19, and this is also the version that defines the saturate attribute (applying to float8 quantization). [1]
  • DequantizeLinear first supports tensor(float8e4m3fn) (via its T1 type constraint) in since_version = 19. [2]

Sources:
[1] QuantizeLinear (ONNX operator doc, version 19) (onnx.ai)
[2] DequantizeLinear (ONNX operator doc, version 19) (onnx.ai)


Add opset >= 19 validation before FP8 Q/DQ conversion in FP8QuantExporter.post_process().

The code converts TRT custom ops to native QuantizeLinear/DequantizeLinear with FLOAT8E4M3FN and the saturate attribute, but does not verify that the model's opset is >= 19 (the minimum required for these operators). When callers invoke get_onnx_bytes_and_metadata() with onnx_opset < 19 on a FP8-quantized model, the post-processor will silently generate an invalid ONNX model instead of upgrading the opset or raising an error.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/export/fp8_exporter.py` around lines 121 - 147, In
FP8QuantExporter.post_process(), before converting
TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear to native
QuantizeLinear/DequantizeLinear using FLOAT8E4M3FN and the saturate attribute,
validate the model opset version is >= 19; locate the method
FP8QuantExporter.post_process and check the graph/model opset (opset_import or
graph.model.opset_import) and if opset < 19 either raise a clear exception
(e.g., ValueError) telling callers to use onnx_opset >= 19 or programmatically
upgrade the model opset to 19 before performing the conversions (and then
proceed with the existing replacement logic for TRT_FP8QuantizeLinear and
TRT_FP8DequantizeLinear).

Comment on lines +576 to +599
onnx_opt_graph = quantize_weights(model, onnx_opt_graph)

if dq_only:
onnx_opt_graph = qdq_to_dq(onnx_opt_graph)

try:
# TODO: Single-precision torch model assumed
param_dtype = next(model.parameters()).dtype
except StopIteration:
param_dtype = torch.float32
if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:
if is_int4_quantized(model) or is_mxfp8_quantized(model):
assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
onnx_opt_graph = convert_float_to_float16(
onnx_opt_graph,
keep_io_types=False,
disable_shape_infer=True,
check_fp16_ready=False,
)
else:
onnx_opt_graph = convert_to_f16(
onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False
)
if weights_dtype == "fp16":
onnx_opt_graph = convert_float_to_float16(
onnx_opt_graph,
keep_io_types=False,
disable_shape_infer=True,
check_fp16_ready=False,
op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
)
# Change FP32 cast nodes feeding into Concat/Add to FP16
onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"])

# TensorRT expects all scales to be postive
onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph)
onnx_opt_graph = remove_redundant_casts(onnx_opt_graph)

# TensorRT expects all scales to be postive
onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph)

# TODO: Remove manual ir_version change once ORT supports ir_version 11
# Must be set after all gs.export_onnx() calls as graphsurgeon resets ir_version
onnx_opt_graph.ir_version = 10
Contributor

⚠️ Potential issue | 🟠 Major

Rebuild model_metadata after the final ONNX rewrites.

Lines 572-574 capture metadata before quantize_weights(), convert_float_to_float16(), change_casts_to_fp16(), and remove_redundant_casts(). Those passes add/remove nodes and can rewrite I/O tensors, so the returned metadata can drift from the serialized model bytes.

🛠️ Proposed fix

Move the metadata creation block below the last graph mutation.

-    model_metadata = create_model_metadata(
-        tree_spec_input, tree_spec_output, input_none_names, onnx_opt_graph, model
-    )

Then re-add it after onnx_opt_graph.ir_version = 10:

+    model_metadata = create_model_metadata(
+        tree_spec_input, tree_spec_output, input_none_names, onnx_opt_graph, model
+    )
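The ordering concern can be illustrated with a hypothetical pipeline shape: run every mutating pass first, then capture metadata last so it always describes the graph that gets serialized (graph, passes, and capture_metadata are illustrative stand-ins for onnx_opt_graph, the rewrite passes, and create_model_metadata):

```python
def finalize_and_capture(graph, passes, capture_metadata):
    """Apply all graph-mutating passes, then build metadata from the result."""
    for p in passes:
        graph = p(graph)
    # Metadata is captured only after the final mutation, so it cannot
    # drift from the serialized bytes.
    return graph, capture_metadata(graph)
```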
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/_deploy/utils/torch_onnx.py` around lines 576 - 599, The
model_metadata is built too early and can become stale after graph rewrites;
move the metadata creation so it runs after all ONNX mutations: after
quantize_weights(), qdq_to_dq(), convert_float_to_float16(),
change_casts_to_fp16(), remove_redundant_casts(), and
replace_zero_scale_with_smallest_nonzero(), and ensure you rebuild it after
setting onnx_opt_graph.ir_version = 10 so the returned metadata matches the
final serialized graph bytes.

Comment thread modelopt/torch/_deploy/utils/torch_onnx.py Outdated
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
modelopt/onnx/utils.py (1)

1507-1511: ⚠️ Potential issue | 🟠 Major

Guard FP32→FP16 cast rewrites by input dtype.

At Line 1507, this rewrites any Cast(to=FLOAT) feeding target ops, including casts from non-FP32 sources. That can silently change behavior by removing intentional upcasts.

💡 Proposed fix
     for node in model.graph.node:
         if node.op_type != "Cast":
             continue
@@
         if not feeds_target:
             continue
 
+        cast_input_type = _get_tensor_type_by_name(model, node.input[0])
+        if cast_input_type != onnx.TensorProto.FLOAT:
+            continue
+
         # Check if Cast is to FP32, and change to FP16
         for attr in node.attribute:
             if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
                 attr.i = onnx.TensorProto.FLOAT16
                 break
🧹 Nitpick comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)

921-923: Prefer public onnx_utils APIs over underscored helpers across modules.

Using _get_tensor_type_by_name, _bypass_cast_node, and _is_same_type_cast from another module couples this class to private implementation details. Exposing public wrappers in modelopt/onnx/utils.py would make this boundary safer.

Also applies to: 942-942, 1097-1102

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/autocast/precisionconverter.py` around lines 921 - 923, The
code currently calls private helpers onnx_utils._get_tensor_type_by_name,
onnx_utils._bypass_cast_node, and onnx_utils._is_same_type_cast from
precisionconverter.py; replace these calls with public wrapper APIs (e.g.,
get_tensor_type_by_name, bypass_cast_node, is_same_type_cast) exported from
modelopt/onnx/utils.py and update precisionconverter.py to call those public
names (also update the other locations that use the underscored helpers around
lines referenced). Add the small public wrapper implementations in
modelopt/onnx/utils.py that delegate to the existing private functions so other
modules use the stable public API and then run tests to ensure behavior is
unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/onnx/utils.py`:
- Around line 1296-1322: The current _is_sequential_cast only compares the two
Cast target types; modify it to also fetch the data type of the original source
feeding the first Cast (e.g., inspect the producer of node.input[0] or its
ValueInfo/initializer) and verify that this source type equals the second cast's
target type (the value returned by get_cast_to_type(next_node)) before returning
True; this extra check ensures that when _bypass_cast_node rewires the graph the
source type is compatible with the second Cast. Use get_consumer_nodes,
get_cast_to_type and the node.input[0] producer lookup to locate and compare
types.

---

Nitpick comments:
In `@modelopt/onnx/autocast/precisionconverter.py`:
- Around line 921-923: The code currently calls private helpers
onnx_utils._get_tensor_type_by_name, onnx_utils._bypass_cast_node, and
onnx_utils._is_same_type_cast from precisionconverter.py; replace these calls
with public wrapper APIs (e.g., get_tensor_type_by_name, bypass_cast_node,
is_same_type_cast) exported from modelopt/onnx/utils.py and update
precisionconverter.py to call those public names (also update the other
locations that use the underscored helpers around lines referenced). Add the
small public wrapper implementations in modelopt/onnx/utils.py that delegate to
the existing private functions so other modules use the stable public API and
then run tests to ensure behavior is unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 188d9d44-97fd-4011-abb0-37b5f6bdbf27

📥 Commits

Reviewing files that changed from the base of the PR and between 788313fa4a998c8f18315b736310f700b29a6a21 and 40ce80f69a78b1880e57e7b3f685ec8301a14095.

📒 Files selected for processing (5)
  • modelopt/onnx/autocast/graphsanitizer.py
  • modelopt/onnx/autocast/precisionconverter.py
  • modelopt/onnx/autocast/utils.py
  • modelopt/onnx/utils.py
  • tests/unit/onnx/autocast/test_precisionconverter.py

Comment thread modelopt/onnx/utils.py
Comment on lines +1296 to +1322
def _is_sequential_cast(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
assert node.op_type == "Cast"
output_type = get_cast_to_type(node)

# Cast to high precision -> cast to low precision, first cast has no impact and can be safely removed
# Cast to low precision -> cast to high precision affects precision and should not be removed
precision_order = [
onnx.TensorProto.DOUBLE,
onnx.TensorProto.FLOAT,
onnx.TensorProto.FLOAT16,
onnx.TensorProto.BFLOAT16,
]
consumers = [n for n in get_consumer_nodes(model, node.output[0]) if n.op_type == "Cast"]

# If the first cast has additional consumers, we should not remove it
if len(consumers) != 1:
return False

next_node = consumers[0]
first_cast_type = output_type
second_cast_type = get_cast_to_type(next_node)

return (
first_cast_type in precision_order
and second_cast_type in precision_order
and precision_order.index(first_cast_type) <= precision_order.index(second_cast_type)
)
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, locate and read the function
cat -n modelopt/onnx/utils.py | sed -n '1290,1330p'

Length of output: 1919


🏁 Script executed:

# Also check if the helper function _get_tensor_type_by_name exists (used in proposed fix)
rg "_get_tensor_type_by_name" modelopt/onnx/utils.py

Length of output: 216


🏁 Script executed:

# Search for get_cast_to_type and get_consumer_nodes to understand them
rg "def get_cast_to_type|def get_consumer_nodes" modelopt/onnx/utils.py

Length of output: 211


🏁 Script executed:

# Check if there are tests for this function
fd "test.*\.py" | xargs rg "_is_sequential_cast" 2>/dev/null || echo "No tests found"

Length of output: 80


🏁 Script executed:

# Find where _is_sequential_cast is called
rg "_is_sequential_cast" modelopt/onnx/utils.py -B 2 -A 2

Length of output: 440


🏁 Script executed:

# Look at the context around _bypass_cast_node to understand what happens when removal occurs
cat -n modelopt/onnx/utils.py | sed -n '1325,1360p'

Length of output: 2048


🏁 Script executed:

# Search for comments or documentation about cast removal logic
rg "sequential.*cast|cast.*removal|bypass.*cast" modelopt/onnx/utils.py -i -B 2 -A 2

Length of output: 1472


🏁 Script executed:

# Look for any test files that might test cast optimization
fd "test.*\.py" modelopt/onnx/ | head -10

Length of output: 48


🏁 Script executed:

# Get more context about _is_same_type_cast to understand the overall pattern
cat -n modelopt/onnx/utils.py | sed -n '1285,1295p'

Length of output: 572


🏁 Script executed:

# Look for the full remove_redundant_casts function to understand the overall flow
rg "def remove_redundant_casts" modelopt/onnx/utils.py -A 30

Length of output: 1393


🏁 Script executed:

# Check if there's any documentation or related test about safe cast removal
fd "*.py" modelopt/onnx/ | xargs grep -l "cast\|Cast" | head -5

Length of output: 349


🏁 Script executed:

# Look at how _get_tensor_type_by_name works and what it can return
rg "def _get_tensor_type_by_name" modelopt/onnx/utils.py -A 20

Length of output: 991


🏁 Script executed:

# Check for any tests or usage examples that show what types node.input can have
fd ".py" modelopt/onnx/ --type f | xargs grep -l "remove_redundant_casts\|_is_sequential_cast" 2>/dev/null

Length of output: 133


🏁 Script executed:

# Look at the overall structure of the optimization and find if there are any tests
fd ".py" --type f modelopt/onnx/ | xargs grep -l "test.*cast\|cast.*test" 2>/dev/null | head -5

Length of output: 48


🏁 Script executed:

# Check the precisionconverter file to see if it has any tests or usage
cat -n modelopt/onnx/autocast/precisionconverter.py | head -50

Length of output: 2246


🏁 Script executed:

# Let me understand the logic better by looking at a concrete example
# First, let me see if there are any examples or documentation about this function
rg "sequential.*cast" modelopt/ -i --type py | head -20

Length of output: 555


🏁 Script executed:

# Check if the source type is ever checked anywhere in the function
cat -n modelopt/onnx/utils.py | sed -n '1296,1323p' | grep -E "input|src|source"

Length of output: 48


🏁 Script executed:

# Look at the broader context to understand how the function comment relates to the logic
cat -n modelopt/onnx/utils.py | sed -n '1296,1323p'

Length of output: 1323


🏁 Script executed:

# Let me trace through a concrete scenario to understand the issue better
# Looking at the comment and logic more carefully
cat -n modelopt/onnx/utils.py | sed -n '1296,1325p'

Length of output: 1415


🏁 Script executed:

# Understand what happens when we remove the first cast by looking at bypass logic
cat -n modelopt/onnx/utils.py | sed -n '1325,1355p'

Length of output: 1685


Sequential-cast removal must validate the original source type against the final cast target.

The function only compares the two cast target types (lines 1318–1322) but doesn't verify that removing the first cast preserves the input type to the second cast. When _bypass_cast_node rewires the graph, it directly connects the original source to the second cast. If the source type differs from what the second cast expects, this changes behavior.

For example:

  • FLOAT16 → cast(FLOAT) → cast(FLOAT16) passes the current check (FLOAT ≤ FLOAT16)
  • After removal: FLOAT16 → cast(FLOAT16) — but cast2 was designed for FLOAT input, causing incorrect behavior

Add a check to ensure the source type matches the second cast's target type before removal.
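A self-contained sketch of that condition, with integer dtype codes standing in for the onnx.TensorProto enums; the extra source-type guard is the review's suggested rule, not the shipped implementation:

```python
DOUBLE, FLOAT, FLOAT16, BFLOAT16 = 11, 1, 10, 16  # onnx.TensorProto values
PRECISION_ORDER = [DOUBLE, FLOAT, FLOAT16, BFLOAT16]

def is_removable_double_cast(source_type, first_cast_to, second_cast_to):
    """True when Cast(first) -> Cast(second) can safely drop the first Cast."""
    if first_cast_to not in PRECISION_ORDER or second_cast_to not in PRECISION_ORDER:
        return False
    widening_then_narrowing = (
        PRECISION_ORDER.index(first_cast_to) <= PRECISION_ORDER.index(second_cast_to)
    )
    # Suggested extra guard: after bypassing the first Cast, the source feeds
    # the second Cast directly, which only preserves semantics when the source
    # already has the second Cast's target type.
    return widening_then_narrowing and source_type == second_cast_to
```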

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/utils.py` around lines 1296 - 1322, The current
_is_sequential_cast only compares the two Cast target types; modify it to also
fetch the data type of the original source feeding the first Cast (e.g., inspect
the producer of node.input[0] or its ValueInfo/initializer) and verify that this
source type equals the second cast's target type (the value returned by
get_cast_to_type(next_node)) before returning True; this extra check ensures
that when _bypass_cast_node rewires the graph the source type is compatible with
the second Cast. Use get_consumer_nodes, get_cast_to_type and the node.input[0]
producer lookup to locate and compare types.

@ajrasane ajrasane enabled auto-merge (squash) March 13, 2026 17:01
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
@ajrasane ajrasane force-pushed the ajrasane/onnx_qdq branch from 40ce80f to 05c33b2 Compare March 13, 2026 17:07
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

♻️ Duplicate comments (3)
modelopt/torch/_deploy/utils/torch_onnx.py (2)

576-576: ⚠️ Potential issue | 🟠 Major

Rebuild model_metadata after the last graph rewrite.

Starting with quantize_weights(), this function mutates node names, tensor dtypes, and Q/DQ structure after model_metadata has already been captured above. The returned metadata can therefore describe a different graph than the bytes written to disk.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/_deploy/utils/torch_onnx.py` at line 576, model_metadata is
captured before quantize_weights mutates the ONNX graph (node names, tensor
dtypes, Q/DQ structure), so the saved metadata can be out of sync with
onnx_opt_graph; after the final graph rewrite (the call to quantize_weights that
returns onnx_opt_graph) re-run the metadata extraction routine (the same
function that produced model_metadata earlier) to rebuild model_metadata from
the mutated onnx_opt_graph so the metadata matches the bytes written to disk;
update any downstream uses to reference the new model_metadata variable produced
after quantize_weights.

581-592: ⚠️ Potential issue | 🟠 Major

The FP16 path is still “autocast + convert_float_to_float16()”.

With this new rewrite block, weights_dtype="fp16" still traces under torch.autocast("cuda") earlier in the function, so FP16 export is not actually using convert_float_to_float16() instead of autocast. That makes the exported graph depend on both mechanisms and undermines the stated pipeline change.

🛠️ Suggested earlier-function change
-    use_torch_autocast = not (
-        is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
-    )
+    use_torch_autocast = weights_dtype == "bf16" and not (
+        is_fp4_quantized(model) or is_mxfp8_quantized(model)
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/_deploy/utils/torch_onnx.py` around lines 581 - 592, The
current FP16 branch still runs under torch.autocast earlier, so the export uses
both autocast and convert_float_to_float16; modify the export flow so when
weights_dtype == "fp16" you do NOT run the model export under
torch.autocast("cuda") (or skip the autocast context) so the graph is produced
in full-precision then transformed only by convert_float_to_float16 and
change_casts_to_fp16; locate the autocast usage earlier in this module
(torch.autocast or the export context manager) and add a conditional to bypass
it when weights_dtype == "fp16", ensuring
convert_float_to_float16/remove_redundant_casts are the sole FP16
transformations.
modelopt/onnx/export/fp8_exporter.py (1)

121-147: ⚠️ Potential issue | 🟠 Major

Reject native FP8 Q/DQ export below opset 19.

This conversion emits native QuantizeLinear/DequantizeLinear with FLOAT8E4M3FN and saturate, but it still doesn't guard against onnx_opset < 19. Exporting FP8 with a lower opset will silently produce an invalid model instead of failing early or upgrading the opset.

🛠️ Suggested guard
     def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+        opset = next(
+            (op.version for op in onnx_model.opset_import if op.domain in ("", "ai.onnx")),
+            0,
+        )
+        if opset < 19:
+            raise ValueError("Native FP8 ONNX Q/DQ requires ai.onnx opset >= 19.")
+
         logger.info("Post-processing FP8 quantized model")
         graph = gs.import_onnx(onnx_model)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/export/fp8_exporter.py` around lines 121 - 147, The conversion
loop for TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear must guard against onnx
opset < 19; detect the model/export opset (e.g., an existing variable or by
examining graph.opset or a passed opset parameter) before performing the
conversions in the loops that change node.op to
"QuantizeLinear"/"DequantizeLinear" and set FLOAT8E4M3FN/saturate, and if the
opset is less than 19 either raise a clear exception or upgrade the opset to >=
19 before modifying nodes (apply this check where you manipulate node.op and
node.attrs in the FP8 exporter function that iterates graph.nodes).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 127-134: The injected zero-point Constant currently uses node.name
which may be empty and cause duplicate tensor names; change the naming for the
Constant created from zp_tensor/zp_values/zero_point to a guaranteed-unique
string (e.g., combine node.name when present with a unique suffix such as a
uuid4 or the node's memory id or an incrementing counter, or use an ONNX/graph
helper that returns a unique name) so each FP8 zero-point Constant has a
distinct tensor name even for unnamed TRT FP8 Q nodes.

In `@modelopt/onnx/utils.py`:
- Around line 1262-1276: The helper _get_tensor_type_by_name must also handle
producer-only tensors emitted by nodes (e.g., a Constant node that produces a
tensor but has no value_info entry); modify _get_tensor_type_by_name to iterate
model.graph.node and when a node.output matches tensor_name, if node.op_type ==
"Constant" extract the TensorProto from the node attribute (attribute named
"value") and return its data_type (or elem_type equivalent), otherwise
skip/continue so non-materialized producer-only tensors do not cause an
exception and allow remove_redundant_casts() to fold Constant->Cast patterns.
- Around line 1403-1417: The two cast-removal checks can both match the same
node causing duplicate removal; make them mutually exclusive by ensuring once a
node is handled by _is_sequential_cast(onnx_model, node) (where you call
_bypass_cast_node and append to nodes_to_remove) you skip the subsequent
_is_foldable_constant_cast_pattern check (e.g., use an elif or continue) so you
only call _bypass_cast_node, _convert_constant_values and append to
nodes_to_remove once; update the block containing _is_sequential_cast,
_is_foldable_constant_cast_pattern, _bypass_cast_node, _convert_constant_values,
get_producer_nodes, nodes_to_remove, and logger.debug accordingly.
- Around line 1499-1511: The current logic uses any(...) on tensor_to_consumers
to set Casts to FP16 even if only one consumer is in target_op_types, which
incorrectly changes shared Cast outputs; modify the check so the Cast is
retargeted only when the entire fanout is eligible (replace the any(...) test
with an all(...) test, and treat empty consumer lists as ineligible/skip), then
keep the existing loop over node.attribute that looks for attr.name == "to" and
change attr.i from onnx.TensorProto.FLOAT to onnx.TensorProto.FLOAT16 only when
that all-consumers condition holds.

---

Duplicate comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 121-147: The conversion loop for
TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear must guard against onnx opset <
19; detect the model/export opset (e.g., an existing variable or by examining
graph.opset or a passed opset parameter) before performing the conversions in
the loops that change node.op to "QuantizeLinear"/"DequantizeLinear" and set
FLOAT8E4M3FN/saturate, and if the opset is less than 19 either raise a clear
exception or upgrade the opset to >= 19 before modifying nodes (apply this check
where you manipulate node.op and node.attrs in the FP8 exporter function that
iterates graph.nodes).

In `@modelopt/torch/_deploy/utils/torch_onnx.py`:
- Line 576: model_metadata is captured before quantize_weights mutates the ONNX
graph (node names, tensor dtypes, Q/DQ structure), so the saved metadata can be
out of sync with onnx_opt_graph; after the final graph rewrite (the call to
quantize_weights that returns onnx_opt_graph) re-run the metadata extraction
routine (the same function that produced model_metadata earlier) to rebuild
model_metadata from the mutated onnx_opt_graph so the metadata matches the bytes
written to disk; update any downstream uses to reference the new model_metadata
variable produced after quantize_weights.
- Around line 581-592: The current FP16 branch still runs under torch.autocast
earlier, so the export uses both autocast and convert_float_to_float16; modify
the export flow so when weights_dtype == "fp16" you do NOT run the model export
under torch.autocast("cuda") (or skip the autocast context) so the graph is
produced in full-precision then transformed only by convert_float_to_float16 and
change_casts_to_fp16; locate the autocast usage earlier in this module
(torch.autocast or the export context manager) and add a conditional to bypass
it when weights_dtype == "fp16", ensuring
convert_float_to_float16/remove_redundant_casts are the sole FP16
transformations.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 7704d15e-7bdf-4b56-b725-8f52b69a07e2

📥 Commits

Reviewing files that changed from the base of the PR and between 40ce80f69a78b1880e57e7b3f685ec8301a14095 and 05c33b2.

📒 Files selected for processing (9)
  • CHANGELOG.rst
  • modelopt/onnx/autocast/graphsanitizer.py
  • modelopt/onnx/autocast/precisionconverter.py
  • modelopt/onnx/autocast/utils.py
  • modelopt/onnx/export/fp8_exporter.py
  • modelopt/onnx/export/nvfp4_exporter.py
  • modelopt/onnx/utils.py
  • modelopt/torch/_deploy/utils/torch_onnx.py
  • tests/unit/onnx/autocast/test_precisionconverter.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/unit/onnx/autocast/test_precisionconverter.py

Comment on lines +127 to +134
# Create FP8 zero point constant
zp_tensor = onnx.TensorProto()
zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
zp_tensor.dims.extend([1]) # 1-element tensor
zp_tensor.raw_data = b"\x00" # Zero in FP8
zp_values = LazyValues(zp_tensor)
zero_point = gs.Constant(node.name + "_zero_point", zp_values)
node.inputs.append(zero_point)
Contributor

⚠️ Potential issue | 🟠 Major

Use a guaranteed-unique tensor name for the injected zero point.

node.name is optional in ONNX, so node.name + "_zero_point" can collapse to the same tensor name for multiple unnamed TRT FP8 Q nodes. That can make the exported graph invalid due to duplicate tensor names.

🛠️ Safer naming
-                    zero_point = gs.Constant(node.name + "_zero_point", zp_values)
+                    zero_point = gs.Constant(f"{node.outputs[0].name}_zero_point", zp_values)
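One way to guarantee uniqueness, sketched with a process-wide counter plus a UUID fallback for unnamed nodes (helper name and scheme are illustrative, not the PR's code):

```python
import itertools
import uuid

_zp_counter = itertools.count()

def unique_zero_point_name(node_name: str) -> str:
    """Derive a collision-free tensor name; ONNX node names may be empty."""
    stem = node_name if node_name else f"unnamed_{uuid.uuid4().hex[:8]}"
    return f"{stem}_zero_point_{next(_zp_counter)}"
```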
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/export/fp8_exporter.py` around lines 127 - 134, The injected
zero-point Constant currently uses node.name which may be empty and cause
duplicate tensor names; change the naming for the Constant created from
zp_tensor/zp_values/zero_point to a guaranteed-unique string (e.g., combine
node.name when present with a unique suffix such as a uuid4 or the node's memory
id or an incrementing counter, or use an ONNX/graph helper that returns a unique
name) so each FP8 zero-point Constant has a distinct tensor name even for
unnamed TRT FP8 Q nodes.

Comment thread modelopt/onnx/utils.py Outdated
Comment thread modelopt/onnx/utils.py
Comment thread modelopt/onnx/utils.py Outdated
@cjluo-nv
Collaborator

Review Comments

Thanks for the PR — the core idea of replacing TRT-specific FP8 QDQ nodes with native ONNX ops is solid, and the refactoring to centralize cast utilities in onnx/utils.py makes sense. However, there are several concerns that I think should be addressed before merging.

1. BF16 regression in _convert_constant_values (High)

The original PrecisionConverter._convert_constant_values had special handling for bfloat16:

  • Reading: used read_f16_tensor_as_fp32() for bf16 input tensors
  • Writing: manually created TensorProto with raw bytes for bf16 output

The new onnx_utils._convert_constant_values (modelopt/onnx/utils.py) uses onnx.numpy_helper.to_array() / from_array() for all types, which doesn't handle bfloat16 natively. This could silently break bf16 constant folding — exactly the kind of AutoCast regression @galagam flagged.
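The bf16 special-casing matters because numpy has no native bfloat16 dtype; bf16 raw data is just the upper 16 bits of each fp32 value. A stdlib-only sketch of that round-trip (truncation shown for brevity; the real helper may round to nearest even):

```python
import struct

def fp32_to_bf16_raw(values):
    """Pack fp32 values as bfloat16 raw bytes by truncating the low 16 bits."""
    out = bytearray()
    for v in values:
        bits = struct.unpack("<I", struct.pack("<f", v))[0]
        out += struct.pack("<H", bits >> 16)
    return bytes(out)

def bf16_raw_to_fp32(raw):
    """Widen bfloat16 raw bytes back to fp32 by zero-filling the low 16 bits."""
    return [
        struct.unpack("<f", struct.pack("<I", h << 16))[0]
        for (h,) in struct.iter_unpack("<H", raw)
    ]
```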

2. BF16 path dropped in get_onnx_bytes_and_metadata (High)

The old code in torch_onnx.py handled weights_dtype in ["fp16", "bf16"] and used convert_to_f16() for the bf16 path. The new code only handles weights_dtype == "fp16". If weights_dtype == "bf16", no FP16/BF16 conversion happens at all. This looks like a silent behavioral regression.

3. Monkey-patching onnxconverter_common with suppress(AttributeError) (Medium)

The patch at torch_onnx.py:59-65 silently swallows all AttributeError exceptions from remove_unnecessary_cast_node. This could mask real bugs. Could you:

  • Add a comment explaining the specific upstream bug this works around?
  • Add a TODO with a link to the upstream issue (if one exists) so this can be removed when fixed?
  • Consider catching more narrowly if possible?
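A self-contained sketch of the narrower pattern; the buggy pass here is a stand-in for onnxconverter_common's remove_unnecessary_cast_node, and the matched message is an assumed signature of the known bug:

```python
import logging

logger = logging.getLogger(__name__)

def buggy_upstream_pass(graph):
    """Stand-in for the upstream pass that raises on some graphs."""
    if graph.get("trigger_bug"):
        raise AttributeError("'NoneType' object has no attribute 'output'")
    graph["cleaned"] = True

def safe_cast_cleanup(graph):
    """Swallow only the known upstream failure; re-raise anything else."""
    try:
        buggy_upstream_pass(graph)
    except AttributeError as e:
        # TODO: drop this workaround once the upstream issue is fixed.
        if "has no attribute 'output'" not in str(e):
            raise
        logger.warning("Skipping cast cleanup (known upstream bug): %s", e)
```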

4. quantize_weights now called unconditionally (Medium)

In torch_onnx.py, the guard if is_int4_quantized(model) or is_fp4_quantized(model) or is_mxfp8_quantized(model) was removed. quantize_weights is now called for all models regardless of quantization mode. Can you confirm this is safe / a no-op for non-quantized models?

5. _get_tensor_type_by_name performance and scope change (Low)

The original PrecisionConverter._get_tensor_type used pre-built O(1) dict lookups (self.value_info_map, self.initializer_map). The new _get_tensor_type_by_name does a linear scan over value_info, initializer, input, and output lists on every call, and also searches graph inputs/outputs which the original didn't. For large models this could be noticeably slower. Worth considering caching or at least noting the tradeoff.
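The tradeoff could be avoided with a one-time index built before the rewrite loop; a sketch whose stand-in objects mirror the field shapes of onnx.GraphProto value_info/initializer entries:

```python
def build_tensor_type_map(graph):
    """Build a name -> elem_type dict once, replacing per-call linear scans."""
    type_map = {}
    for vi in list(graph.value_info) + list(graph.input) + list(graph.output):
        type_map[vi.name] = vi.type.tensor_type.elem_type
    for init in graph.initializer:
        type_map[init.name] = init.data_type
    return type_map
```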

6. Minor: in _is_same_type_cast, the check input_types is not None is always True

input_types is a list comprehension result, so it's always a list (never None). The check and input_types is not None is a no-op. Copied from the original code but worth cleaning up while you're here.


Items 1, 2, and 4 are the ones I'd like to see addressed or explicitly justified before approving. The rest are suggestions. Thanks!

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (3)
modelopt/onnx/utils.py (3)

1420-1435: ⚠️ Potential issue | 🟠 Major

Missing continue after _is_sequential_cast can cause double processing.

A node matching both _is_sequential_cast and _is_foldable_constant_cast_pattern (e.g., Constant → Cast → Cast) will be processed by both blocks, leading to duplicate entries in nodes_to_remove and potential errors during removal.

🛠️ Proposed fix
             # Find sequential casts that don't change precision
             if _is_sequential_cast(onnx_model, node):
                 nodes_to_remove.append(node)
                 _bypass_cast_node(onnx_model, node)
                 logger.debug(f"Found removable double-cast: {node.name}")
+                continue

             # Find foldable Constant -> Cast. Initializers are handled by _convert_initializers.
             if _is_foldable_constant_cast_pattern(onnx_model, node):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/utils.py` around lines 1420 - 1435, The sequential-cast branch
(_is_sequential_cast) can fall through and be reprocessed by the foldable
Constant->Cast branch, causing duplicate entries in nodes_to_remove and
double-modification of the same node; after handling a sequential cast (where
you call _bypass_cast_node and log), add a control-flow break (e.g., a continue)
to skip further checks for that node so it isn't also processed by
_is_foldable_constant_cast_pattern, ensuring nodes_to_remove, _bypass_cast_node,
_convert_constant_values, get_producer_nodes and the node aren't acted on twice.

1516-1528: ⚠️ Potential issue | 🟠 Major

Using any() can incorrectly change shared Cast outputs affecting non-target consumers.

When a Cast node's output feeds multiple consumers and only some are in target_op_types, changing the Cast to FP16 affects all consumers, potentially breaking non-target branches.

🛠️ Proposed fix - only change if ALL consumers are target ops
         # Check if this Cast outputs to a target op type
         cast_output = node.output[0]
         consumers = tensor_to_consumers.get(cast_output, [])
-        feeds_target = any(c.op_type in target_op_types for c in consumers)
+        feeds_target = bool(consumers) and all(c.op_type in target_op_types for c in consumers)

         if not feeds_target:
             continue
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/utils.py` around lines 1516 - 1528, The current logic uses
any() and flips a Cast to FP16 even when some consumers are non-target, which
can break those branches; change the feeds_target condition to only be true when
there is at least one consumer and ALL consumers are in target_op_types (i.e.,
replace feeds_target = any(...) with feeds_target = bool(consumers) and
all(c.op_type in target_op_types for c in consumers)), then proceed with the
existing node.attribute loop that checks attr.name == "to" and attr.i to change
FLOAT to FLOAT16.

1262-1277: ⚠️ Potential issue | 🟡 Minor

Function does not handle Constant node outputs that lack value_info entries.

When a Constant node produces an output tensor that isn't registered in value_info, initializer, input, or output, this function raises an exception. This can occur before _is_foldable_constant_cast_pattern gets a chance to handle it.

Consider checking node outputs for Constant nodes:

🛠️ Proposed fix
 def _get_tensor_type_by_name(model: onnx.ModelProto, tensor_name: str):
     """Get the tensor element type. Searches value_info, initializers, inputs, and outputs."""
     for vi in model.graph.value_info:
         if vi.name == tensor_name:
             return vi.type.tensor_type.elem_type
     for init in model.graph.initializer:
         if init.name == tensor_name:
             return init.data_type
     for inp in model.graph.input:
         if inp.name == tensor_name:
             return inp.type.tensor_type.elem_type
     for out in model.graph.output:
         if out.name == tensor_name:
             return out.type.tensor_type.elem_type
+    # Check Constant node outputs
+    for node in model.graph.node:
+        if node.op_type == "Constant" and tensor_name in node.output:
+            for attr in node.attribute:
+                if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR:
+                    return attr.t.data_type
+            break
     raise Exception(f"did not find tensor {tensor_name}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/utils.py` around lines 1262 - 1277, _get_tensor_type_by_name
currently misses Constant node outputs that aren't listed in
value_info/initializer/input/output; update it to also scan model.graph.node for
nodes with op_type == "Constant" whose output list contains tensor_name, then
find the node AttributeProto with name "value" (the TensorProto stored on the
Constant) and return its data_type (attr.t.data_type). Keep the existing checks
first (value_info/initializer/input/output), add the Constant-node check before
raising the Exception, and otherwise keep the same error behavior.
🧹 Nitpick comments (2)
modelopt/onnx/utils.py (1)

1289-1294: Redundant None check on list comprehension result.

input_types is always a list (from the list comprehension), so input_types is not None is always True. This check can be simplified.

♻️ Proposed fix
 def _is_same_type_cast(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
     assert node.op_type == "Cast"
     input_types = [_get_tensor_type_by_name(model, inp) for inp in node.input]
     output_type = get_cast_to_type(node)
-    return all(inp_type == output_type for inp_type in input_types) and input_types is not None
+    return bool(input_types) and all(inp_type == output_type for inp_type in input_types)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/utils.py` around lines 1289 - 1294, The function
_is_same_type_cast has a redundant "input_types is not None" check because
input_types is always a list from the list comprehension; update the return
statement in _is_same_type_cast (referencing variables input_types and
output_type) to simply return all(inp_type == output_type for inp_type in
input_types), removing the unnecessary None check.
modelopt/torch/_deploy/utils/torch_onnx.py (1)

573-576: model_metadata created before graph mutations may have stale onnx_node_names.

Metadata is captured at line 573-575 before quantize_weights, FP16 conversion, and remove_redundant_casts. While I/O names and shapes should remain stable, onnx_node_names will be stale after remove_redundant_casts removes Cast nodes.

If onnx_node_names is used for validation or debugging, consider moving metadata creation after line 606 (after all mutations).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/_deploy/utils/torch_onnx.py` around lines 573 - 576, The code
creates model_metadata via create_model_metadata(...) before performing graph
mutations, which can leave onnx_node_names stale; move the model_metadata =
create_model_metadata(tree_spec_input, tree_spec_output, input_none_names,
onnx_opt_graph, model) call to after the transformations (after
quantize_weights, FP16 conversion/convert_model_to_fp16, and
remove_redundant_casts) so that onnx_opt_graph reflects the final node set;
ensure you reference the same onnx_opt_graph and model when recreating metadata
so onnx_node_names are accurate for validation/debugging.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/onnx/utils.py`:
- Around line 1296-1323: In _is_sequential_cast, add validation of the original
source tensor type before approving removal: fetch the source type via
_get_tensor_type_by_name(model, node.input[0]) and compute its index in the
existing precision_order; then ensure source_index <= index(first_cast_type) <=
index(second_cast_type) (using get_cast_to_type for first/second casts and
get_consumer_nodes to find the next Cast) and only return True when this
three-way ordering holds; if the source type cannot be resolved or any type not
in precision_order, return False.

In `@modelopt/torch/_deploy/utils/torch_onnx.py`:
- Around line 63-73: Add a detailed comment above the monkey-patch for
_f16_module.remove_unnecessary_cast_node explaining the upstream bug being
worked around: reference the upstream issue/PR URL (or issue number) that
describes the "downstream_node is a list" failure, explicitly state the exact
AttributeError message being suppressed (e.g., "'list' object has no attribute
'some_attr'") and under what conditions it occurs, and document the minimum
onnxconverter_common version (or commit/PR) that fixes it so this workaround can
be removed; include the unique symbols _original_remove_unnecessary_cast_node,
_patched_remove_unnecessary_cast_node, and suppress(AttributeError) in the
comment so readers can quickly locate the patched code.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: ba926212-e494-4ed1-a553-8657733f661f

📥 Commits

Reviewing files that changed from the base of the PR and between 05c33b2 and f6ce7b3.

📒 Files selected for processing (2)
  • modelopt/onnx/utils.py
  • modelopt/torch/_deploy/utils/torch_onnx.py

Comment thread modelopt/onnx/utils.py
Comment on lines +1296 to +1323
def _is_sequential_cast(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
    assert node.op_type == "Cast"
    output_type = get_cast_to_type(node)

    # Cast to high precision -> cast to low precision, first cast has no impact and can be safely removed
    # Cast to low precision -> cast to high precision affects precision and should not be removed
    precision_order = [
        onnx.TensorProto.DOUBLE,
        onnx.TensorProto.FLOAT,
        onnx.TensorProto.FLOAT16,
        onnx.TensorProto.BFLOAT16,
    ]
    consumers = [n for n in get_consumer_nodes(model, node.output[0]) if n.op_type == "Cast"]

    # If the first cast has additional consumers, we should not remove it
    if len(consumers) != 1:
        return False

    next_node = consumers[0]
    first_cast_type = output_type
    second_cast_type = get_cast_to_type(next_node)

    return (
        first_cast_type in precision_order
        and second_cast_type in precision_order
        and precision_order.index(first_cast_type) <= precision_order.index(second_cast_type)
    )

Contributor


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if there are unit tests covering _is_sequential_cast edge cases
rg -l "_is_sequential_cast|remove_redundant_casts" --glob "*test*.py"

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

#!/bin/bash
# Search for calls to _is_sequential_cast and understand its usage context
rg "_is_sequential_cast" -B 5 -A 5 modelopt/onnx/utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1187


🏁 Script executed:

#!/bin/bash
# Look at how node inputs are accessed in the file
rg "node\.input\[0\]" modelopt/onnx/utils.py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 239


🏁 Script executed:

#!/bin/bash
# Check if there are helper functions to get input nodes or types
rg "def get_input|def.*input" modelopt/onnx/utils.py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 1040


🏁 Script executed:

#!/bin/bash
# Look at helper function _get_tensor_type_by_name used in the file
rg "_get_tensor_type_by_name" modelopt/onnx/utils.py -A 10 | head -30

Repository: NVIDIA/Model-Optimizer

Length of output: 1166


🏁 Script executed:

#!/bin/bash
# Check the full context of _is_sequential_cast including surrounding functions
sed -n '1280,1330p' modelopt/onnx/utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 2090


🏁 Script executed:

#!/bin/bash
# Search for get_producer_nodes to understand how to get source node info
rg "def get_producer_nodes" modelopt/onnx/utils.py -A 5

Repository: NVIDIA/Model-Optimizer

Length of output: 325


🏁 Script executed:

#!/bin/bash
# Get the exact signature and full implementation of _is_sequential_cast to confirm
sed -n '1296,1323p' modelopt/onnx/utils.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 1323


🏁 Script executed:

#!/bin/bash
# Check if there are any recent commits or comments about source type validation
git log --oneline --all -S "source_type" -- modelopt/onnx/utils.py | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 48


Add source type validation to prevent incorrect cast optimization.

The function only compares the first and second cast output types but ignores the original source type. This creates a semantic bug where incorrect optimizations can occur.

Example: If source is BFLOAT16 → Cast to DOUBLE → Cast to FLOAT16, the function currently approves removal of the first cast (since DOUBLE index ≤ FLOAT16 index), resulting in a direct BFLOAT16 → FLOAT16 cast. This differs from the original two-cast path through higher precision.

Fix: Retrieve the source type using _get_tensor_type_by_name(model, node.input[0]) and validate that source_type_index <= first_cast_type_index <= second_cast_type_index before approving removal.
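The proposed three-way ordering can be sketched as below. This is a hedged, self-contained model of the rule, not the modelopt code: the type constants are the real onnx.TensorProto enum values written as plain ints, and is_removable_sequential_cast is a hypothetical helper that takes the already-resolved source and cast-to types.

```python
# Sketch of the suggested source-aware check. precision_order mirrors the snippet
# above (highest precision first); types are onnx.TensorProto enum values as ints.

FLOAT, FLOAT16, DOUBLE, BFLOAT16 = 1, 10, 11, 16
precision_order = [DOUBLE, FLOAT, FLOAT16, BFLOAT16]

def is_removable_sequential_cast(source_type, first_cast_type, second_cast_type) -> bool:
    if any(t not in precision_order for t in (source_type, first_cast_type, second_cast_type)):
        return False
    src, first, second = (
        precision_order.index(t) for t in (source_type, first_cast_type, second_cast_type)
    )
    # The first cast is redundant only if precision never increases along the
    # chain relative to the source: src <= first <= second in precision_order.
    return src <= first <= second

# BFLOAT16 -> DOUBLE -> FLOAT16: the old two-way check approves removal,
# the source-aware check rejects it, matching the reviewer's example.
assert not is_removable_sequential_cast(BFLOAT16, DOUBLE, FLOAT16)
# FLOAT -> FLOAT16 -> BFLOAT16 is monotonically decreasing, so removal passes.
assert is_removable_sequential_cast(FLOAT, FLOAT16, BFLOAT16)
```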


Comment thread modelopt/torch/_deploy/utils/torch_onnx.py Outdated
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
@ajrasane ajrasane force-pushed the ajrasane/onnx_qdq branch from 4dda7bf to 9e3a35a on March 16, 2026 18:41
@ajrasane
Contributor Author

@cjluo-nv,
I have added fixes for 1, 2, 3, 5, and 6.
For 4, the code internally checks whether the respective quantizers are present in the model, so it is also safe for non-quantized models. Hence I haven't made any changes there.

Collaborator

@cjluo-nv cjluo-nv left a comment


Review

The refactoring direction (moving graph utilities to onnx/utils.py) is sound, the _build_tensor_type_map O(1) optimization is nice, and accuracy results show negligible regression. A few correctness issues to address before approving:

1. Bug: Missing continue in remove_redundant_casts() (onnx/utils.py)

The _is_sequential_cast and _is_foldable_constant_cast_pattern checks are not mutually exclusive. A Constant -> Cast -> Cast pattern satisfies both. Without a continue after the sequential-cast branch, _bypass_cast_node gets called twice on the same node and it gets appended to nodes_to_remove twice. The second bypass operates on already-modified graph connections, which could corrupt the graph.

# After this block, add `continue`:
if _is_sequential_cast(onnx_model, node):
    nodes_to_remove.append(node)
    _bypass_cast_node(onnx_model, node)
    logger.debug(f"Found removable double-cast: {node.name}")
    continue  # <-- missing

2. _get_tensor_type_by_name() can throw on Constant-produced tensors

This helper only searches value_info, initializers, inputs, and outputs. But Constant node outputs are often not materialized in value_info. When _is_same_type_cast calls this on a Cast fed by a Constant, it will raise Exception("did not find tensor ..."). Consider also scanning Constant node attributes in _build_tensor_type_map, or handling the exception gracefully in _is_same_type_cast.
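Extending the one-pass map to cover Constant-produced tensors could look like the sketch below. It is a simplified stand-in under stated assumptions: nodes are plain dicts mimicking NodeProto, and value_data_type stands in for the data_type of the TensorProto stored in the Constant's "value" attribute.

```python
# Sketch: record Constant node outputs in the type map as well, since they are
# often absent from value_info. Dicts stand in for onnx protobuf messages.

def build_tensor_type_map(graph) -> dict:
    type_map = {}
    for vi in graph.get("value_info", []):
        type_map[vi["name"]] = vi["elem_type"]
    for init in graph.get("initializer", []):
        type_map[init["name"]] = init["data_type"]
    # Constant outputs: take the type from the node's stored tensor value.
    for node in graph.get("node", []):
        if node["op_type"] == "Constant":
            for out in node["output"]:
                type_map.setdefault(out, node["value_data_type"])
    return type_map

graph = {
    "value_info": [],
    "initializer": [],
    "node": [{"op_type": "Constant", "output": ["const0"], "value_data_type": 1}],
}
# A Constant-fed Cast can now resolve its input type instead of raising.
assert build_tensor_type_map(graph)["const0"] == 1
```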

3. change_casts_to_fp16() is overly broad

  • Uses any() — if a Cast output feeds both a Concat and a non-target op, the Cast is flipped to FP16 for all consumers, potentially breaking the non-target branch. Consider using all() or only retargeting when the entire fanout is eligible.
  • Doesn't verify source type — a Cast from FP64→FP32 would get incorrectly changed to FP64→FP16.
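Both bullets can be combined into one guard, sketched here. This is a hypothetical predicate, not the modelopt function: it takes the cast's "to" type and the already-collected consumer op types, requires the entire fanout to be eligible, and only flips FP32 casts.

```python
# Sketch: retarget a Cast to FP16 only when (a) every consumer is a target op
# and (b) the cast currently produces FLOAT, so FP64 casts are left untouched.

FLOAT, FLOAT16, DOUBLE = 1, 10, 11  # onnx.TensorProto enum values as ints

def should_retarget_cast(cast_to_type, consumer_op_types, target_op_types) -> bool:
    all_targets = bool(consumer_op_types) and all(
        op in target_op_types for op in consumer_op_types
    )
    return all_targets and cast_to_type == FLOAT

targets = {"Concat", "Add"}
assert should_retarget_cast(FLOAT, ["Concat", "Add"], targets)
assert not should_retarget_cast(FLOAT, ["Concat", "Relu"], targets)  # mixed fanout
assert not should_retarget_cast(DOUBLE, ["Concat"], targets)         # FP64 source
```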

4. quantize_weights() is now unconditional (torch_onnx.py)

Previously gated by is_int4_quantized or is_fp4_quantized or is_mxfp8_quantized. Does quantize_weights() no-op safely for non-quantized models?

5. Minor: Zero-point tensor naming (fp8_exporter.py)

node.name + "_zero_point" could produce duplicate names if multiple TRT_FP8QuantizeLinear nodes have empty names. Using node.output[0] as the name root would be safer.
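The suggested rename is trivial but worth pinning down: deriving the initializer name from the node's first output guarantees uniqueness, since outputs must be unique in a valid ONNX graph while node names may be empty. A minimal sketch (zero_point_name is a hypothetical helper):

```python
# Sketch: name the zero-point initializer after the node's first output, which
# is unique in a valid ONNX graph, rather than the possibly-empty node name.

def zero_point_name(node_output0: str) -> str:
    return node_output0 + "_zero_point"

# Two quantize nodes with empty names but distinct outputs no longer collide:
a = zero_point_name("quant_out_0")
b = zero_point_name("quant_out_1")
assert a != b
```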

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
@kevalmorabia97 kevalmorabia97 requested a review from cjluo-nv March 17, 2026 04:38
@kevalmorabia97 kevalmorabia97 disabled auto-merge March 17, 2026 05:43
@kevalmorabia97 kevalmorabia97 merged commit e4df91b into main Mar 17, 2026
52 of 54 checks passed
@kevalmorabia97 kevalmorabia97 deleted the ajrasane/onnx_qdq branch March 17, 2026 05:44
Collaborator

@cjluo-nv cjluo-nv left a comment


LGTM — the latest commit addresses the key review items (missing continue, Constant tensor type lookup, change_casts_to_fp16 scoping, and quantize_weights no-op guard).

Two minor nits for follow-up:

  • fp8_exporter.py: node.name + "_zero_point" could produce duplicate names if nodes have empty names — consider using node.output[0] as the name root instead.
  • torch_onnx.py: bare print("No quantization exporters found...") — consider logger.info() for consistency.

Neither is blocking.
